TRPO (Trust Region Policy Optimization) — low-level PyTorch implementation#

TRPO is an on-policy policy-gradient method that achieves stable, approximately monotonic policy improvement by constraining how much the policy is allowed to change at each iteration via a KL-divergence trust region.

In this notebook you will:

  • Derive the KL-constrained objective (in LaTeX) and see how it leads to a natural-gradient step

  • Implement TRPO “from scratch” with PyTorch autograd + conjugate gradient + backtracking line search

  • Visualize policy updates, KL per update, and episodic returns with Plotly

  • See a reference Stable-Baselines TRPO implementation and understand its hyperparameters

Notebook roadmap#

  1. TRPO objective + the KL-divergence constraint (math)

  2. A tiny offline-friendly continuous-control environment (no downloads)

  3. Gaussian policy + value baseline (PyTorch)

  4. GAE advantages + value function fit

  5. TRPO update step (Fisher-vector product, conjugate gradient, line search)

  6. Plotly: episodic rewards, KL constraint, policy update snapshots

  7. Stable-Baselines TRPO: usage + hyperparameters (end)

import os
import sys
import time

import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots

import torch
import torch.nn as nn
import torch.nn.functional as F

pio.templates.default = "plotly_white"
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")

np.set_printoptions(precision=4, suppress=True)

DEVICE = torch.device("cpu")

SEED = 42
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)
print("Python:", sys.version.split()[0])
print("NumPy:", np.__version__)
import plotly

print("Plotly:", plotly.__version__)
print("PyTorch:", torch.__version__)
print("Device:", DEVICE)
Python: 3.12.9
NumPy: 1.26.2
Plotly: 6.5.2
PyTorch: 2.7.0+cu126
Device: cpu

1) TRPO objective and the KL-divergence constraint#

TRPO is usually presented as the constrained optimization problem:

[ \max_\theta\; \mathbb{E}_{s,a\sim \pi_{\theta_{\text{old}}}}\left[\frac{\pi_\theta(a\mid s)}{\pi_{\theta_{\text{old}}}(a\mid s)}\,\hat A_{\theta_{\text{old}}}(s,a)\right] \qquad\text{s.t.}\qquad \mathbb{E}_{s\sim \pi_{\theta_{\text{old}}}}\left[D_{\mathrm{KL}}\!\left(\pi_{\theta_{\text{old}}}(\cdot\mid s)\,\|\,\pi_\theta(\cdot\mid s)\right)\right] \le \delta. ]

The trust region is defined by the average KL divergence under the state distribution visited by the old policy. Intuition: “move in a direction that increases the objective, but don’t move too far in policy space.”

We use the standard definition:

[ D_{\mathrm{KL}}(p\,\|\,q) = \mathbb{E}_{x\sim p}\left[\log\frac{p(x)}{q(x)}\right]. ]
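As a quick illustrative check of this definition (the two Gaussians below are arbitrary and not reused later), a Monte Carlo estimate of the KL from samples of (p) should match PyTorch’s closed form:

# Monte Carlo check of the KL definition for two 1D Gaussians (illustrative only).
p_dist = torch.distributions.Normal(0.0, 1.0)
q_dist = torch.distributions.Normal(0.5, 1.5)

x = p_dist.sample((200_000,))                             # x ~ p
kl_mc = (p_dist.log_prob(x) - q_dist.log_prob(x)).mean()  # E_{x~p}[log p(x) - log q(x)]
kl_exact = torch.distributions.kl_divergence(p_dist, q_dist)
print(f"MC estimate: {kl_mc:.4f} | closed form: {kl_exact:.4f}")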

1.1) Why this leads to a natural-gradient step#

Let (\theta) be the policy parameters and (\theta_{\text{old}}) the pre-update parameters.

TRPO uses two approximations around (\theta_{\text{old}}):

  • First-order (linear) approximation of the surrogate objective:

[ L(\theta) \approx L(\theta_{\text{old}}) + g^\top (\theta - \theta_{\text{old}}) \quad\text{where}\quad g = \nabla_\theta L(\theta)\big\rvert_{\theta=\theta_{\text{old}}}. ]

  • Second-order (quadratic) approximation of the KL constraint:

[ \bar D_{\mathrm{KL}}(\theta_{\text{old}},\theta) \approx \tfrac12 (\theta - \theta_{\text{old}})^\top H (\theta - \theta_{\text{old}}), ]

where (H) is the Hessian of the average KL at (\theta_{\text{old}}) (equivalently, the policy’s Fisher information matrix for common exponential-family policies).

Define the step (p = \theta - \theta_{\text{old}}). The constrained problem becomes:

[ \max_p\; g^\top p \qquad\text{s.t.}\qquad \tfrac12\, p^\top H p \le \delta. ]

The solution is:

[ p^* = \sqrt{\frac{2\delta}{g^\top H^{-1} g}}\; H^{-1} g. ]

So we need:

  1. The policy-gradient (g)

  2. The product (H^{-1} g) (without forming (H) explicitly) → conjugate gradient + Hessian-vector products

  3. A step scaling + backtracking line search to satisfy the true KL constraint and improve the surrogate.
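Before building all of that, here is a tiny numerical sketch of the closed-form step, with a random (g) and a random SPD matrix standing in for the real gradient and Fisher matrix (nothing here is reused later). The step should saturate the quadratic KL constraint at exactly (\delta):

# Toy check of the natural-gradient step on a random quadratic trust region.
delta = 0.01
g = torch.randn(5)
A = torch.randn(5, 5)
H = A @ A.T + 0.1 * torch.eye(5)   # random symmetric positive-definite "Fisher"

Hinv_g = torch.linalg.solve(H, g)  # explicit H^{-1} g (the real code uses CG instead)
p_star = torch.sqrt(2.0 * delta / (g @ Hinv_g)) * Hinv_g

print("linearized improvement g^T p*:", float(g @ p_star))
print("(1/2) p*^T H p* =", float(0.5 * p_star @ (H @ p_star)), "(should equal delta =", delta, ")")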

2) A tiny offline-friendly continuous-control environment#

To keep the notebook self-contained (no Gym downloads), we use a 1D point-mass with state (s=(x,v)) and action (a\in[-1,1]):

  • Dynamics: small acceleration changes velocity, velocity changes position

  • Goal: reach (x=0) with small velocity

  • Reward: negative quadratic cost (plus a small terminal bonus when reaching the goal)

This is not meant to be a benchmark; it’s just enough to show that TRPO learns and that the KL trust region stabilizes updates.

class PointMass1DEnv:
    def __init__(
        self,
        dt: float = 0.05,
        max_steps: int = 150,
        x_init_range: float = 2.0,
        v_init_range: float = 0.5,
        action_max: float = 1.0,
        goal_x: float = 0.0,
        goal_tol: float = 0.05,
        goal_bonus: float = 5.0,
        seed: int | None = None,
    ):
        self.dt = float(dt)
        self.max_steps = int(max_steps)
        self.x_init_range = float(x_init_range)
        self.v_init_range = float(v_init_range)
        self.action_max = float(action_max)
        self.goal_x = float(goal_x)
        self.goal_tol = float(goal_tol)
        self.goal_bonus = float(goal_bonus)
        self.rng = np.random.default_rng(seed)

        self.steps = 0
        self.x = 0.0
        self.v = 0.0

    @property
    def obs_dim(self):
        return 2

    @property
    def act_dim(self):
        return 1

    def reset(self, seed: int | None = None):
        if seed is not None:
            self.rng = np.random.default_rng(seed)
        self.steps = 0
        self.x = self.rng.uniform(-self.x_init_range, self.x_init_range)
        self.v = self.rng.uniform(-self.v_init_range, self.v_init_range)
        return np.array([self.x, self.v], dtype=np.float32)

    def step(self, action):
        # Accept scalars or shape-(1,) arrays; indexing avoids NumPy's
        # deprecated array-to-scalar conversion.
        a = float(np.clip(np.asarray(action).reshape(-1)[0], -self.action_max, self.action_max))

        # simple damped dynamics
        self.v = 0.99 * self.v + a * self.dt
        self.x = self.x + self.v * self.dt
        self.steps += 1

        # quadratic cost around the goal
        cost = (self.x - self.goal_x) ** 2 + 0.1 * (self.v**2) + 0.001 * (a**2)
        reward = -float(cost)

        done = False
        if abs(self.x - self.goal_x) < self.goal_tol and abs(self.v) < self.goal_tol:
            done = True
            reward += float(self.goal_bonus)
        if self.steps >= self.max_steps:
            done = True

        obs = np.array([self.x, self.v], dtype=np.float32)
        return obs, reward, done, {}
env = PointMass1DEnv(seed=SEED)
obs = env.reset()

xs, vs, acts, rews = [obs[0]], [obs[1]], [], []
done = False
while not done:
    a = rng.uniform(-1.0, 1.0)
    obs, r, done, _ = env.step(a)
    xs.append(obs[0])
    vs.append(obs[1])
    acts.append(a)
    rews.append(r)

fig = make_subplots(rows=3, cols=1, shared_xaxes=True)
fig.add_trace(go.Scatter(y=xs, mode="lines", name="x"), row=1, col=1)
fig.add_trace(go.Scatter(y=vs, mode="lines", name="v"), row=2, col=1)
fig.add_trace(go.Scatter(y=acts, mode="lines", name="a"), row=3, col=1)
fig.update_layout(
    title="One random rollout in the toy env",
    height=650,
    showlegend=True,
)
fig.update_yaxes(title_text="position x", row=1, col=1)
fig.update_yaxes(title_text="velocity v", row=2, col=1)
fig.update_yaxes(title_text="action a", row=3, col=1)
fig.update_xaxes(title_text="time step", row=3, col=1)
fig.show()

print("Return (sum reward):", float(np.sum(rews)))
Return (sum reward): -206.57660868987693

3) Policy and value function (PyTorch)#

We’ll use:

  • A Gaussian policy (\pi_\theta(a\mid s)=\mathcal{N}(\mu_\theta(s),\sigma_\theta^2)) with diagonal covariance and a state-independent learned log-std (here 1D)

  • A value network (V_\phi(s)) as a baseline

For TRPO we need:

  • (\log \pi_\theta(a\mid s)) to compute the surrogate objective

  • The KL between old and new Gaussian policies to build the trust region (and its Hessian-vector product)

def mlp(sizes, activation=nn.Tanh, output_activation=nn.Identity):
    layers = []
    for i in range(len(sizes) - 1):
        act = activation if i < len(sizes) - 2 else output_activation
        layers.append(nn.Linear(sizes[i], sizes[i + 1]))
        layers.append(act())
    return nn.Sequential(*layers)


class GaussianPolicy(nn.Module):
    def __init__(self, obs_dim: int, act_dim: int, hidden_sizes=(64, 64)):
        super().__init__()
        self.net = mlp([obs_dim, *hidden_sizes, act_dim], activation=nn.Tanh)
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def forward(self, obs: torch.Tensor):
        mean = self.net(obs)
        log_std = self.log_std.expand_as(mean)
        return mean, log_std

    def dist(self, obs: torch.Tensor):
        mean, log_std = self.forward(obs)
        return torch.distributions.Normal(mean, torch.exp(log_std))

    @torch.no_grad()
    def act(self, obs: torch.Tensor):
        dist = self.dist(obs)
        action = dist.sample()
        logp = dist.log_prob(action).sum(-1)
        return action, logp


class ValueNet(nn.Module):
    def __init__(self, obs_dim: int, hidden_sizes=(64, 64)):
        super().__init__()
        self.net = mlp([obs_dim, *hidden_sizes, 1], activation=nn.Tanh)

    def forward(self, obs: torch.Tensor):
        return self.net(obs).squeeze(-1)
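A quick smoke test of the two modules above (batch size 4 and the leading-underscore names are arbitrary):

# Smoke test: action, log-prob, and value shapes for a small random batch.
_pi = GaussianPolicy(obs_dim=2, act_dim=1)
_vf = ValueNet(obs_dim=2)
_obs = torch.randn(4, 2)

_a, _logp = _pi.act(_obs)
print("action:", tuple(_a.shape), "| logp:", tuple(_logp.shape), "| value:", tuple(_vf(_obs).shape))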

4) TRPO building blocks (low-level)#

We implement:

  • GAE((\gamma,\lambda)) for advantages

  • Value function regression

  • Conjugate gradient for solving (H x = g)

  • Fisher/Hessian-vector product via autograd on the mean KL

  • Backtracking line search enforcing the KL constraint

def gaussian_kl(mean_old, log_std_old, mean_new, log_std_new):
    """KL( N_old || N_new ) for diagonal Gaussians; returns shape (batch,)."""
    var_old = torch.exp(2.0 * log_std_old)
    var_new = torch.exp(2.0 * log_std_new)
    kl_per_dim = (
        log_std_new
        - log_std_old
        + (var_old + (mean_old - mean_new) ** 2) / (2.0 * var_new)
        - 0.5
    )
    return kl_per_dim.sum(dim=-1)


def flat_params(model: nn.Module):
    return torch.cat([p.data.view(-1) for p in model.parameters()])


def set_flat_params(model: nn.Module, flat: torch.Tensor):
    idx = 0
    with torch.no_grad():
        for p in model.parameters():
            n = p.numel()
            p.copy_(flat[idx : idx + n].view_as(p))
            idx += n


def flat_grad(grads, params):
    out = []
    for g, p in zip(grads, params):
        if g is None:
            out.append(torch.zeros_like(p).view(-1))
        else:
            out.append(g.contiguous().view(-1))
    return torch.cat(out)


def conjugate_gradient(fvp_fn, b, cg_iters=10, residual_tol=1e-10):
    """Solve H x = b using only matrix-vector products fvp_fn(v) = H v."""
    x = torch.zeros_like(b)
    r = b.clone()
    p = b.clone()
    rdotr = torch.dot(r, r)

    for _ in range(cg_iters):
        Avp = fvp_fn(p)
        alpha = rdotr / (torch.dot(p, Avp) + 1e-8)
        x = x + alpha * p
        r = r - alpha * Avp
        new_rdotr = torch.dot(r, r)
        if new_rdotr < residual_tol:
            break
        beta = new_rdotr / (rdotr + 1e-8)
        p = r + beta * p
        rdotr = new_rdotr

    return x


def trpo_update(
    policy: GaussianPolicy,
    obs: torch.Tensor,
    act: torch.Tensor,
    adv: torch.Tensor,
    logp_old: torch.Tensor,
    max_kl: float = 0.01,
    cg_iters: int = 10,
    cg_damping: float = 1e-2,
    backtrack_iters: int = 10,
    backtrack_coeff: float = 0.8,
):
    """One TRPO policy update step."""

    params = list(policy.parameters())
    old_params = flat_params(policy)

    with torch.no_grad():
        mean_old, log_std_old = policy.forward(obs)
        mean_old = mean_old.detach()
        log_std_old = log_std_old.detach()

    def surrogate():
        dist = policy.dist(obs)
        logp = dist.log_prob(act).sum(-1)
        ratio = torch.exp(logp - logp_old)
        return (ratio * adv).mean()

    def mean_kl():
        mean_new, log_std_new = policy.forward(obs)
        return gaussian_kl(mean_old, log_std_old, mean_new, log_std_new).mean()

    surr = surrogate()
    g = torch.autograd.grad(surr, params, retain_graph=True, allow_unused=True)
    g_flat = flat_grad(g, params).detach()

    def fvp(v):
        kl = mean_kl()
        grads = torch.autograd.grad(kl, params, create_graph=True, allow_unused=True)
        flat_kl_grad = flat_grad(grads, params)
        kl_v = torch.dot(flat_kl_grad, v)
        grads2 = torch.autograd.grad(kl_v, params, allow_unused=True)
        hvp = flat_grad(grads2, params).detach()
        return hvp + cg_damping * v

    step_dir = conjugate_gradient(fvp, g_flat, cg_iters=cg_iters)
    shs = torch.dot(step_dir, fvp(step_dir))
    step_size = torch.sqrt(torch.tensor(2.0 * max_kl, dtype=shs.dtype) / (shs + 1e-8))
    full_step = step_dir * step_size

    def eval_surr_and_kl():
        with torch.no_grad():
            s = surrogate().item()
            k = mean_kl().item()
        return s, k

    surr_old_val, _ = eval_surr_and_kl()

    step_frac = 1.0
    accepted = False
    surr_new_val = surr_old_val
    kl_new_val = 0.0

    for _ in range(backtrack_iters):
        new_params = old_params + step_frac * full_step
        set_flat_params(policy, new_params)

        surr_new_val, kl_new_val = eval_surr_and_kl()

        if (surr_new_val > surr_old_val) and (kl_new_val <= max_kl):
            accepted = True
            break
        step_frac *= backtrack_coeff

    if not accepted:
        set_flat_params(policy, old_params)

    return {
        "surr_old": float(surr_old_val),
        "surr_new": float(surr_new_val),
        "kl": float(kl_new_val),
        "step_frac": float(step_frac if accepted else 0.0),
        "accepted": bool(accepted),
    }
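Three quick sanity checks of the building blocks on small random inputs (illustrative only): gaussian_kl should agree with torch.distributions, conjugate gradient should match a direct solve on an explicit SPD system, and the double-backprop Hessian-vector product should match an explicit Hessian.

# Check 1: gaussian_kl vs. torch.distributions for diagonal Gaussians.
m0, ls0 = torch.randn(8, 3), 0.1 * torch.randn(8, 3)
m1, ls1 = torch.randn(8, 3), 0.1 * torch.randn(8, 3)
kl_ref = torch.distributions.kl_divergence(
    torch.distributions.Normal(m0, ls0.exp()),
    torch.distributions.Normal(m1, ls1.exp()),
).sum(-1)
print("max |KL diff| :", float((gaussian_kl(m0, ls0, m1, ls1) - kl_ref).abs().max()))

# Check 2: conjugate gradient vs. a direct solve on a small SPD system.
A = torch.randn(6, 6)
H = A @ A.T + 0.1 * torch.eye(6)
b = torch.randn(6)
x_cg = conjugate_gradient(lambda v: H @ v, b, cg_iters=50)
print("max |CG diff| :", float((x_cg - torch.linalg.solve(H, b)).abs().max()))

# Check 3: Hessian-vector product via double backprop on f(w) = 0.5 w^T H w,
# whose Hessian is exactly H (same trick as fvp above).
w = torch.randn(6, requires_grad=True)
(grad_f,) = torch.autograd.grad(0.5 * w @ (H @ w), w, create_graph=True)
v = torch.randn(6)
(hvp,) = torch.autograd.grad(grad_f @ v, w)
print("max |HVP diff|:", float((hvp - H @ v).abs().max()))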
def collect_batch(env, policy, value_net, steps_per_batch, gamma=0.99, lam=0.98):
    """Collect `steps_per_batch` on-policy transitions and compute GAE advantages."""
    obs_buf = np.zeros((steps_per_batch, env.obs_dim), dtype=np.float32)
    act_buf = np.zeros((steps_per_batch, env.act_dim), dtype=np.float32)
    rew_buf = np.zeros(steps_per_batch, dtype=np.float32)
    done_buf = np.zeros(steps_per_batch, dtype=np.float32)
    val_buf = np.zeros(steps_per_batch, dtype=np.float32)
    logp_buf = np.zeros(steps_per_batch, dtype=np.float32)

    ep_returns = []
    ep_ret = 0.0

    obs = env.reset()

    for t in range(steps_per_batch):
        obs_t = torch.as_tensor(obs, dtype=torch.float32, device=DEVICE).unsqueeze(0)
        with torch.no_grad():
            a_t, logp_t = policy.act(obs_t)
            v_t = value_net(obs_t)

        a = a_t.squeeze(0).cpu().numpy()
        logp = float(logp_t.item())
        v = float(v_t.item())

        next_obs, r, done, _ = env.step(a)

        obs_buf[t] = obs
        act_buf[t] = a
        rew_buf[t] = r
        done_buf[t] = float(done)
        val_buf[t] = v
        logp_buf[t] = logp

        ep_ret += float(r)

        obs = next_obs
        if done:
            ep_returns.append(ep_ret)
            ep_ret = 0.0
            obs = env.reset()

    # bootstrap value for the last state (if last transition wasn't terminal)
    with torch.no_grad():
        last_val = value_net(
            torch.as_tensor(obs, dtype=torch.float32, device=DEVICE).unsqueeze(0)
        ).item()

    adv_buf = np.zeros(steps_per_batch, dtype=np.float32)
    last_gae = 0.0

    for t in reversed(range(steps_per_batch)):
        next_nonterminal = 1.0 - done_buf[t]
        next_value = last_val if t == steps_per_batch - 1 else val_buf[t + 1]

        delta = rew_buf[t] + gamma * next_value * next_nonterminal - val_buf[t]
        last_gae = delta + gamma * lam * next_nonterminal * last_gae
        adv_buf[t] = last_gae

    ret_buf = adv_buf + val_buf

    # normalize advantages (very common and usually helpful)
    adv_buf = (adv_buf - adv_buf.mean()) / (adv_buf.std() + 1e-8)

    batch = {
        "obs": torch.as_tensor(obs_buf, dtype=torch.float32, device=DEVICE),
        "act": torch.as_tensor(act_buf, dtype=torch.float32, device=DEVICE),
        "logp_old": torch.as_tensor(logp_buf, dtype=torch.float32, device=DEVICE),
        "adv": torch.as_tensor(adv_buf, dtype=torch.float32, device=DEVICE),
        "ret": torch.as_tensor(ret_buf, dtype=torch.float32, device=DEVICE),
        "ep_returns": ep_returns,
    }
    return batch
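To make the GAE recursion concrete, here is a minimal restatement of the loop above on three hand-picked numbers (no terminals; the rewards and values are made up). With (\lambda=0) the advantages collapse to the one-step TD residuals (\delta_t); with (\lambda=1) they become the discounted Monte Carlo advantage:

# Minimal GAE illustration (same recursion as collect_batch, terminals dropped).
def gae_demo(rews, vals, last_val, gamma, lam):
    adv = np.zeros(len(rews))
    last = 0.0
    for t in reversed(range(len(rews))):
        nxt = last_val if t == len(rews) - 1 else vals[t + 1]
        delta = rews[t] + gamma * nxt - vals[t]
        last = delta + gamma * lam * last
        adv[t] = last
    return adv

demo_rews = np.array([1.0, 0.0, 1.0])
demo_vals = np.array([0.5, 0.4, 0.3])
print("lam=0 (TD residuals):", gae_demo(demo_rews, demo_vals, last_val=0.2, gamma=0.99, lam=0.0))
print("lam=1 (Monte Carlo) :", gae_demo(demo_rews, demo_vals, last_val=0.2, gamma=0.99, lam=1.0))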
# --- Run configuration ---
FAST_RUN = True  # set False for a longer run

TOTAL_ITERS = 25 if FAST_RUN else 150
STEPS_PER_BATCH = 1024 if FAST_RUN else 4096

GAMMA = 0.99
LAMBDA = 0.98

MAX_KL = 0.01
CG_ITERS = 10
CG_DAMPING = 1e-2
BACKTRACK_ITERS = 10
BACKTRACK_COEFF = 0.8

VF_LR = 3e-4
VF_ITERS = 10 if FAST_RUN else 80
VF_BATCH = 128

SNAPSHOT_EVERY = 5

env = PointMass1DEnv(seed=SEED)
policy = GaussianPolicy(env.obs_dim, env.act_dim, hidden_sizes=(64, 64)).to(DEVICE)
value_net = ValueNet(env.obs_dim, hidden_sizes=(64, 64)).to(DEVICE)

vf_optim = torch.optim.Adam(value_net.parameters(), lr=VF_LR)

x_grid = np.linspace(-env.x_init_range, env.x_init_range, 101, dtype=np.float32)
history = {
    "iter": [],
    "ep_ret_mean": [],
    "ep_ret_p10": [],
    "ep_ret_p90": [],
    "kl": [],
    "surr_old": [],
    "surr_new": [],
    "step_frac": [],
    "policy_std": [],
}

policy_snapshots = []

t0 = time.time()

for it in range(TOTAL_ITERS):
    batch = collect_batch(
        env,
        policy,
        value_net,
        steps_per_batch=STEPS_PER_BATCH,
        gamma=GAMMA,
        lam=LAMBDA,
    )

    # --- Fit value function ---
    for _ in range(VF_ITERS):
        n = batch["obs"].shape[0]
        bs = min(VF_BATCH, n)
        idx = torch.as_tensor(rng.choice(n, size=bs, replace=False), device=DEVICE)
        v_pred = value_net(batch["obs"][idx])
        v_loss = F.mse_loss(v_pred, batch["ret"][idx])
        vf_optim.zero_grad()
        v_loss.backward()
        vf_optim.step()

    # --- TRPO policy update ---
    stats = trpo_update(
        policy,
        obs=batch["obs"],
        act=batch["act"],
        adv=batch["adv"],
        logp_old=batch["logp_old"],
        max_kl=MAX_KL,
        cg_iters=CG_ITERS,
        cg_damping=CG_DAMPING,
        backtrack_iters=BACKTRACK_ITERS,
        backtrack_coeff=BACKTRACK_COEFF,
    )

    # --- Metrics ---
    ep_returns = batch["ep_returns"]
    if len(ep_returns) > 0:
        ep_mean = float(np.mean(ep_returns))
        ep_p10 = float(np.percentile(ep_returns, 10))
        ep_p90 = float(np.percentile(ep_returns, 90))
    else:
        ep_mean, ep_p10, ep_p90 = float("nan"), float("nan"), float("nan")

    with torch.no_grad():
        policy_std = float(torch.exp(policy.log_std).mean().item())

    history["iter"].append(it)
    history["ep_ret_mean"].append(ep_mean)
    history["ep_ret_p10"].append(ep_p10)
    history["ep_ret_p90"].append(ep_p90)
    history["kl"].append(stats["kl"])
    history["surr_old"].append(stats["surr_old"])
    history["surr_new"].append(stats["surr_new"])
    history["step_frac"].append(stats["step_frac"])
    history["policy_std"].append(policy_std)

    # snapshot policy mean(action|x,v=0) over a grid
    if (it % SNAPSHOT_EVERY == 0) or (it == TOTAL_ITERS - 1):
        obs_grid = np.stack([x_grid, np.zeros_like(x_grid)], axis=1)
        with torch.no_grad():
            mu, _ = policy.forward(torch.as_tensor(obs_grid, dtype=torch.float32, device=DEVICE))
        policy_snapshots.append({"iter": it, "mu": mu.squeeze(-1).cpu().numpy()})

    if (it + 1) % max(1, TOTAL_ITERS // 5) == 0 or it == 0:
        print(
            f"iter {it:03d} | ep_ret_mean {ep_mean:8.2f} | KL {stats['kl']:.4f} | "
            f"step_frac {stats['step_frac']:.3f} | std {policy_std:.3f}"
        )

print(f"Done in {time.time() - t0:.2f}s")
iter 000 | ep_ret_mean -1083.82 | KL 0.0086 | step_frac 1.000 | std 0.997
iter 004 | ep_ret_mean  -333.98 | KL 0.0092 | step_frac 1.000 | std 0.948
iter 009 | ep_ret_mean  -117.50 | KL 0.0070 | step_frac 0.800 | std 0.940
iter 014 | ep_ret_mean -2175.23 | KL 0.0068 | step_frac 0.640 | std 1.018
iter 019 | ep_ret_mean -2436.67 | KL 0.0049 | step_frac 1.000 | std 1.005
iter 024 | ep_ret_mean   -63.65 | KL 0.0089 | step_frac 1.000 | std 1.044
Done in 3.64s
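Before plotting diagnostics, a quick deterministic evaluation sketch: roll out the trained policy with its mean action instead of sampling (the eval seed 123 is arbitrary).

# Deterministic rollout with the trained policy mean (no exploration noise).
eval_obs = env.reset(seed=123)
eval_ret, eval_done = 0.0, False
while not eval_done:
    with torch.no_grad():
        mu, _ = policy.forward(
            torch.as_tensor(eval_obs, dtype=torch.float32, device=DEVICE).unsqueeze(0)
        )
    eval_obs, r, eval_done, _ = env.step(mu.squeeze(0).cpu().numpy())
    eval_ret += r
print("Deterministic eval return:", round(eval_ret, 2))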
# Plotly: learning curves and trust-region diagnostics

iters = history["iter"]

fig = make_subplots(
    rows=3,
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.08,
    subplot_titles=(
        "Episodic return (mean + 10/90 percentile band)",
        "Mean KL(old || new) per update (should be ≤ max_kl)",
        "Policy std (exp(log_std))",
    ),
)

# return band
fig.add_trace(
    go.Scatter(x=iters, y=history["ep_ret_p90"], mode="lines", line=dict(width=0), showlegend=False),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(
        x=iters,
        y=history["ep_ret_p10"],
        mode="lines",
        fill="tonexty",
        line=dict(width=0),
        name="p10–p90",
        opacity=0.25,
    ),
    row=1,
    col=1,
)

fig.add_trace(
    go.Scatter(x=iters, y=history["ep_ret_mean"], mode="lines+markers", name="mean"),
    row=1,
    col=1,
)

# KL curve
fig.add_trace(
    go.Scatter(x=iters, y=history["kl"], mode="lines+markers", name="KL"),
    row=2,
    col=1,
)
fig.add_hline(y=MAX_KL, line_dash="dash", line_color="black", row=2, col=1)

# policy std
fig.add_trace(
    go.Scatter(x=iters, y=history["policy_std"], mode="lines+markers", name="std"),
    row=3,
    col=1,
)

fig.update_layout(height=850, title="TRPO learning diagnostics")
fig.update_xaxes(title_text="iteration", row=3, col=1)
fig.update_yaxes(title_text="return", row=1, col=1)
fig.update_yaxes(title_text="KL", row=2, col=1)
fig.update_yaxes(title_text="std", row=3, col=1)
fig.show()
# Plotly: how the policy mean changes over iterations

fig = go.Figure()
for snap in policy_snapshots:
    fig.add_trace(
        go.Scatter(
            x=x_grid,
            y=snap["mu"],
            mode="lines",
            name=f"iter {snap['iter']}",
        )
    )

fig.update_layout(
    title="Policy mean action μ(x, v=0) snapshots",
    xaxis_title="position x (with v fixed at 0)",
    yaxis_title="mean action μ",
    height=450,
)
fig.show()

5) Stable-Baselines TRPO (reference implementation)#

TRPO is implemented in the original stable-baselines (TensorFlow) project as stable_baselines.trpo_mpi.TRPO, and is re-exported as stable_baselines.TRPO when mpi4py is installed.

Example usage (not executed here):

import gym

# Requires the original stable-baselines (TensorFlow) + mpi4py.
from stable_baselines import TRPO
from stable_baselines.common.policies import MlpPolicy

env = gym.make("CartPole-v1")
model = TRPO(
    MlpPolicy,
    env,
    gamma=0.99,
    timesteps_per_batch=1024,
    max_kl=0.01,
    cg_iters=10,
    lam=0.98,
    entcoeff=0.0,
    cg_damping=1e-2,
    vf_stepsize=3e-4,
    vf_iters=3,
    verbose=1,
)
model.learn(total_timesteps=200_000)

Source used to verify signature and defaults:

  • https://github.com/hill-a/stable-baselines/blob/master/stable_baselines/trpo_mpi/trpo_mpi.py

Stable-Baselines TRPO hyperparameters (what they mean)#

From the upstream TRPO.__init__ signature:

  • gamma — discount factor (\gamma)

  • timesteps_per_batch — on-policy batch size (number of environment steps collected before each TRPO update)

  • max_kl — trust-region radius (\delta): target/upper bound on mean KL(old || new)

  • cg_iters — number of conjugate-gradient iterations used to approximately solve (H x = g)

  • lam — GAE parameter (\lambda) controlling bias/variance tradeoff in advantages

  • entcoeff — entropy bonus coefficient (encourages exploration by penalizing low entropy)

  • cg_damping — adds a small multiple of the identity to the Fisher/Hessian-vector product for numerical stability

  • vf_stepsize — learning rate for the value function optimizer

  • vf_iters — number of value-function optimization iterations per update

  • tensorboard_log / full_tensorboard_log — logging configuration

  • policy_kwargs — extra arguments passed to the policy network constructor

  • seed — RNG seed

  • n_cpu_tf_sess — TensorFlow session CPU threading configuration

A good way to tune TRPO is to start with:

  • max_kl around 0.01; increase it for faster (but riskier) learning, decrease it for more stability

  • a larger timesteps_per_batch for smoother, lower-variance updates (at higher compute cost)

  • a slightly larger cg_damping if updates become numerically unstable
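For reference, a rough correspondence between those Stable-Baselines kwargs and the constants used in this notebook’s from-scratch loop (a mapping sketch, not an exact API equivalence; entcoeff has no analogue here because our surrogate has no entropy bonus):

# Rough mapping: Stable-Baselines TRPO kwargs -> this notebook's constants.
sb_to_notebook = {
    "gamma": GAMMA,                          # 0.99
    "timesteps_per_batch": STEPS_PER_BATCH,  # on-policy batch size
    "max_kl": MAX_KL,                        # trust-region radius delta
    "cg_iters": CG_ITERS,
    "lam": LAMBDA,                           # GAE lambda
    "cg_damping": CG_DAMPING,
    "vf_stepsize": VF_LR,
    "vf_iters": VF_ITERS,
}
print(sb_to_notebook)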